{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7cd73ca8-afd1-41ca-943e-a3598a4a9549",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/user/PycharmProjects/graphormer-pyg/venv/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch_geometric.datasets import MoleculeNet\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.data import Data\n",
    "from torch.utils.data import Subset\n",
    "from torch_geometric.utils.convert import to_networkx\n",
    "from networkx import all_pairs_shortest_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c4c26cb6-f12b-43ea-9482-455cbb4e42d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: rdkit in ./venv/lib/python3.9/site-packages (2023.3.1)\n",
      "Requirement already satisfied: numpy in ./venv/lib/python3.9/site-packages (from rdkit) (1.24.3)\n",
      "Requirement already satisfied: Pillow in ./venv/lib/python3.9/site-packages (from rdkit) (9.5.0)\n",
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.1.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "!pip install rdkit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1d360dec-8745-4fee-adcf-4aeb6c35be95",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ESOL(1128)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = MoleculeNet(root=\"./\", name=\"ESOL\")\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9dcd5b5d-54ac-4ffe-903e-c788b55cbec7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import torch_geometric.nn as tgnn\n",
    "from graphormer.model import Graphormer\n",
    "\n",
    "\n",
    "model = Graphormer(\n",
    "    num_layers=3,\n",
    "    input_node_dim=dataset.num_node_features,\n",
    "    node_dim=128,\n",
    "    input_edge_dim=dataset.num_edge_features,\n",
    "    edge_dim=128,\n",
    "    output_dim=dataset[0].y.shape[1],\n",
    "    n_heads=4,\n",
    "    ff_dim=256,\n",
    "    max_in_degree=5,\n",
    "    max_out_degree=5,\n",
    "    max_path_distance=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cb0b624d-7d7d-423c-acf2-5c1e0d033917",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "test_ids, train_ids = train_test_split([i for i in range(len(dataset))], test_size=0.8, random_state=42)\n",
    "train_loader = DataLoader(Subset(dataset, train_ids), batch_size=1)\n",
    "test_loader = DataLoader(Subset(dataset, test_ids), batch_size=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "af514e6d-cfd9-46b0-8725-45b422f9db83",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
    "loss_functin = nn.L1Loss(reduction=\"sum\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37f42e66-2ca7-457d-9218-080bd1400ab0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|█████▋                                                                                                                                  | 38/903 [00:13<02:56,  4.91it/s]"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "from torch_geometric.nn.pool import global_mean_pool\n",
    "\n",
    "DEVICE = \"cpu\"\n",
    "\n",
    "model.to(DEVICE)\n",
    "for epoch in range(10):\n",
    "    model.train()\n",
    "    batch_loss = 0.0\n",
    "    for batch in tqdm(train_loader):\n",
    "        batch.to(DEVICE)\n",
    "        y = batch.y\n",
    "        optimizer.zero_grad()\n",
    "        output = global_mean_pool(model(batch), batch.batch)\n",
    "        loss = loss_functin(output, y)\n",
    "        batch_loss += loss.item()\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "        optimizer.step()\n",
    "    print(\"TRAIN_LOSS\", batch_loss / len(train_ids))\n",
    "\n",
    "    model.eval()\n",
    "    batch_loss = 0.0\n",
    "    for batch in tqdm(test_loader):\n",
    "        batch.to(DEVICE)\n",
    "        y = batch.y\n",
    "        with torch.no_grad():\n",
    "            output = global_mean_pool(model(batch), batch.batch)\n",
    "            loss = loss_functin(output, y)\n",
    "            \n",
    "        batch_loss += loss.item()\n",
    "\n",
    "    print(\"EVAL LOSS\", batch_loss / len(test_ids))\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06efd4f9-cc83-42e4-be58-34815323b728",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
